from transformers import AutoModel, AutoTokenizer
import streamlit as st


# Streamlit page chrome: browser tab title, robot favicon, wide layout.
# Must run before any other Streamlit call in the script.
st.set_page_config(
    page_title="CoHelperGLM",
    page_icon=":robot:",
    layout='wide'
)


@st.cache_resource
def get_model():
    """Load the CodeGeeX2 tokenizer and model once and cache them.

    Returns:
        tuple: (tokenizer, model) — the model is placed on CUDA and put
        into eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained("CODEGEEX2", trust_remote_code=True)
    # Full-precision CodeGeeX2-6B on the GPU; swap in the commented line
    # below to use the int4-quantized variant instead.
    model = AutoModel.from_pretrained("CODEGEEX2", trust_remote_code=True, device='cuda').eval()
    # model = AutoModel.from_pretrained("CODEGEEX2", trust_remote_code=True).quantize(4).to("cuda").eval()
    return tokenizer, model

 
# Materialise the cached tokenizer/model pair for this session.
tokenizer, model = get_model()

st.title("Prompts in, Bugs out")

# Sidebar generation controls: cap on total generated sequence length.
max_length = st.sidebar.slider(
    'max_length', 0, 2560, 512, step=1
)
# top-k sampling cutoff; 1 effectively means greedy decoding.
top_k = st.sidebar.slider(
    'top_k', 1, 10, 1, step=1
)
# With top_k pinned to 1 decoding is effectively greedy, so the nucleus
# sampling controls are rendered but greyed out; otherwise they are live.
# (disabled=False is Streamlit's default, so one code path covers both cases.)
sampling_locked = int(top_k) == 1
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.8, step=0.01, disabled=sampling_locked
)
temperature = st.sidebar.slider(
    'temperature', 0.0, 1.0, 0.95, step=0.01, disabled=sampling_locked
)


# Sidebar radio picking the target programming language for the prompt.
# (Label "语言选择：" means "language selection" — user-facing runtime
# string, deliberately left untranslated.)
genre = st.sidebar.radio(
    "语言选择：",
    ('Python', 'HTML', 'Shell', "Go", "C++", "Java", "JavaScript"))



# Map each radio choice to (display name, language-tag line, line-comment
# marker).  Replaces a 30-line if/elif chain; values are byte-identical to
# the originals.  Fix: the original HTML branch never assigned `prefix`,
# leaving a latent NameError — HTML now carries its block-comment opener
# (unused by pre_text, since HTML prompts go through html_text instead).
_LANG_SETTINGS = {
    'Shell': ("shell", "# language: Shell", "#"),
    'Python': ("Python", "# language: Python", "#"),
    'HTML': ("HTML", "<!--language: HTML-->", "<!--"),
    'C++': ("C++", "// language: C++", "//"),
    'Go': ("Go", "// language: Go", "//"),
    'Java': ("Java", "// language: Java", "//"),
    'JavaScript': ("JavaScript", "// language: JavaScript", "//"),
}
# Unknown values fall back to Python, matching the original else branch.
lan, p, prefix = _LANG_SETTINGS.get(genre, _LANG_SETTINGS['Python'])

# Per-session chat transcript: list of (query, response) pairs.
if 'history' not in st.session_state:
    st.session_state.history = []

# Transformer KV-cache slot.  NOTE(review): it is initialised and written
# back below, but the generation call passes past_key_values=None, so no
# state actually carries between turns — confirm this is intentional.
if 'past_key_values' not in st.session_state:
    st.session_state.past_key_values = None

# Replay the conversation so far; each response is rendered as a fenced
# code block tagged with the currently selected language.
for i, (query, response) in enumerate(st.session_state.history):
    with st.chat_message(name="user", avatar="user"):
        st.text(query)
    with st.chat_message(name="assistant", avatar="assistant"):
        st.markdown("```{}\n".format(lan.lower())+response+"\n```")
# Empty placeholders for the in-flight turn, filled after the user submits.
with st.chat_message(name="user", avatar="user"):
    input_placeholder = st.empty()
with st.chat_message(name="assistant", avatar="assistant"):
    message_placeholder = st.empty()

# Free-form prompt input.  The Chinese label/placeholder ("user command
# input, e.g.: write a bubble-sort algorithm") are user-facing runtime
# strings and are left untranslated.
prompt_text = st.text_area(label="用户命令输入",
                           height=100,
                           placeholder="用户命令输入,例如：写一个冒泡排序算法")

# "发送" = "send"; clicking triggers the generation branch below.
button = st.button("发送", key="predict")



def pre_text(s, prefix):
    """Turn every non-empty line of *s* into a line comment.

    Lines already starting with *prefix* are kept as-is; empty lines are
    dropped.

    Args:
        s: Raw multi-line prompt text.
        prefix: Line-comment marker of the target language (e.g. "#", "//").

    Returns:
        The commented prompt, newline-joined.
    """
    res = []
    for line in s.split('\n'):
        if line:
            # Fix: the original tested line[0].startswith(prefix), i.e. only
            # the first character — that can never match a multi-character
            # prefix such as "//", so already-commented C++/Go/Java/JS lines
            # were commented a second time.
            if line.startswith(prefix):
                res.append(line)
            else:
                res.append(prefix + ' ' + line)
    return '\n'.join(res)

def html_text(s):
    """Wrap every non-empty line of *s* in an HTML block comment.

    Lines that are already a complete ``<!-- ... -->`` comment are kept
    unchanged; empty lines are dropped.

    Args:
        s: Raw multi-line prompt text.

    Returns:
        The commented prompt, newline-joined.
    """
    res = []
    for line in s.split('\n'):
        if line:
            # Fix: the original skipped wrapping for lines shorter than 4
            # characters; such lines can never be a complete comment (the
            # shortest is "<!---->") and slicing short strings is safe in
            # Python, so the length guard only caused short lines to be
            # passed to the model uncommented.
            if line.startswith("<!--") and line.endswith("-->"):
                res.append(line)
            else:
                res.append('<!--' + line + '-->')
    return '\n'.join(res)
# Handle a click on the send button: comment the user's prompt, prepend the
# language tag, and stream the model's completion into the placeholder.
if button:
    # HTML uses block comments; every other language uses a line prefix.
    if lan.lower() == "html":
        prompt_t = html_text(prompt_text)
    else:
        prompt_t = pre_text(prompt_text, prefix)
    # Prepend the "language: X" tag line that CodeGeeX2 expects.
    prompt_t = "{}\n{}\n".format(p, prompt_t)
    input_placeholder.text(prompt_text)
    history, past_key_values = st.session_state.history, st.session_state.past_key_values

    # NOTE(review): the call below ignores both the session history and the
    # KV cache (history=[], past_key_values=None, return_past_key_values=
    # False), so every turn is generated statelessly; `history` and
    # `past_key_values` loaded above are written back unchanged.  Confirm
    # this is intentional before "fixing" it.
    # NOTE(review): if stream_chat yields nothing, `response` is unbound and
    # the append below raises NameError.
    for response, history in model.stream_chat(tokenizer, prompt_t, [],
                                                                past_key_values=None,
                                                                max_length=max_length, top_p=top_p, top_k=top_k,
                                                                temperature=temperature,
                                                                return_past_key_values=False):
        # Re-render the growing completion as a fenced code block.
        message_placeholder.markdown("```{}\n".format(lan.lower())+response+"\n```")

    st.session_state.history.append((prompt_text, response))
    st.session_state.past_key_values = past_key_values